# https://raw.githubusercontent.com/hankcs/udacity-deep-learning/master/6_lstm.py

# After training a skip-gram model in `5_word2vec.ipynb`, the goal of this notebook is
# to train an LSTM character model over the [Text8](http://mattmahoney.net/dc/textdata) data.

# These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
from __future__ import print_function

import math
import os
import random
import string
import zipfile

import numpy as np
import tensorflow as tf
from six.moves import range
from six.moves.urllib.request import urlretrieve

url = 'http://mattmahoney.net/dc/'
def maybe_download(filename, expected_bytes):
    """Download a file if not present, and make sure it's the right size."""
    if not os.path.exists(filename):
        filename, _ = urlretrieve(url + filename, filename)
    statinfo = os.stat(filename)
    if statinfo.st_size == expected_bytes:
        print('Found and verified %s' % filename)
    else:
        print(statinfo.st_size)
        raise Exception(
            'Failed to verify ' + filename + '. Can you get to it with a browser?')
    return filename


filename = maybe_download('text8.zip', 31344016)
def read_data(filename):
    """Extract the first file enclosed in a zip file as a string."""
    with zipfile.ZipFile(filename) as f:
        return tf.compat.as_str(f.read(f.namelist()[0]))


text = read_data(filename)
print('Data size %d' % len(text))
# Create a small validation set.
valid_size = 1000
valid_text = text[:valid_size]
train_text = text[valid_size:]
train_size = len(train_text)
print(train_size, train_text[:64])
print(valid_size, valid_text[:64])
# Utility functions to map characters to vocabulary IDs and back.
vocabulary_size = len(string.ascii_lowercase) + 1  # [a-z] + ' '
first_letter = ord(string.ascii_lowercase[0])


def char2id(char):
    if char in string.ascii_lowercase:
        return ord(char) - first_letter + 1
    elif char == ' ':
        return 0
    else:
        print('Unexpected character: %s' % char)
        return 0


def id2char(dictid):
    if dictid > 0:
        return chr(dictid + first_letter - 1)
    else:
        return ' '


print(char2id('a'), char2id('z'), char2id(' '), char2id('ï'))
print(id2char(1), id2char(26), id2char(0))
# Function to generate a training batch for the LSTM model.
batch_size = 64
num_unrollings = 10


class BatchGenerator(object):
    def __init__(self, text, batch_size, num_unrollings):
        self._text = text
        self._text_size = len(text)
        self._batch_size = batch_size
        self._num_unrollings = num_unrollings
        segment = self._text_size // batch_size
        self._cursor = [offset * segment for offset in range(batch_size)]
        self._last_batch = self._next_batch()

    def _next_batch(self):
        """Generate a single batch from the current cursor position in the data."""
        batch = np.zeros(shape=(self._batch_size, vocabulary_size), dtype=np.float)
        for b in range(self._batch_size):
            batch[b, char2id(self._text[self._cursor[b]])] = 1.0
            self._cursor[b] = (self._cursor[b] + 1) % self._text_size
        return batch

    def next(self):
        """Generate the next array of batches from the data. The array consists of
        the last batch of the previous array, followed by num_unrollings new ones."""
        batches = [self._last_batch]
        for step in range(self._num_unrollings):
            batches.append(self._next_batch())
        self._last_batch = batches[-1]
        return batches
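# Each of the batch_size cursors starts `segment` characters apart, so next()
# streams batch_size independent slices of the text in parallel. The first batch
# of every call repeats the last batch of the previous call, which is what lets
# the labels (inputs shifted by one step) line up across successive calls.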
def characters(probabilities):
    """Turn a 1-hot encoding or a probability distribution over the possible
    characters back into its (most likely) character representation."""
    return [id2char(c) for c in np.argmax(probabilities, 1)]


def batches2string(batches):
    """Convert a sequence of batches back into their (most likely) string
    representation."""
    s = [''] * batches[0].shape[0]
    for b in batches:
        s = [''.join(x) for x in zip(s, characters(b))]
    return s


train_batches = BatchGenerator(train_text, batch_size, num_unrollings)
valid_batches = BatchGenerator(valid_text, 1, 1)

print(batches2string(train_batches.next()))
print(batches2string(train_batches.next()))
print(batches2string(valid_batches.next()))
print(batches2string(valid_batches.next()))
def logprob(predictions, labels):
    """Log-probability of the true labels in a predicted batch."""
    predictions[predictions < 1e-10] = 1e-10
    return np.sum(np.multiply(labels, -np.log(predictions))) / labels.shape[0]
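# The perplexity numbers printed during training are just exp() of this mean
# negative log-probability. A quick sanity check (illustrative only, using the
# helpers defined above): a uniform guess over the vocabulary should score a
# perplexity of about vocabulary_size (i.e. ~27).
_uniform = np.full((1, vocabulary_size), 1.0 / vocabulary_size)
_one_hot = np.zeros((1, vocabulary_size))
_one_hot[0, 0] = 1.0
print('Uniform-guess perplexity: %.2f' % np.exp(logprob(_uniform, _one_hot)))  # ~27.00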
def sample_distribution(distribution):
    """Sample one element from a distribution assumed to be an array of normalized
    probabilities."""
    r = random.uniform(0, 1)
    s = 0
    for i in range(len(distribution)):
        s += distribution[i]
        if s >= r:
            return i
    return len(distribution) - 1


def sample(prediction):
    """Turn a (column) prediction into 1-hot encoded samples."""
    p = np.zeros(shape=[1, vocabulary_size], dtype=np.float)
    p[0, sample_distribution(prediction[0])] = 1.0
    return p


def random_distribution():
    """Generate a random column of probabilities."""
    b = np.random.uniform(0.0, 1.0, size=[1, vocabulary_size])
    return b / np.sum(b, 1)[:, None]
# Simple LSTM Model.
num_nodes = 64

graph = tf.Graph()
with graph.as_default():
    # Parameters:
    # Input gate: input, previous output, and bias.
    ix = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
    im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
    ib = tf.Variable(tf.zeros([1, num_nodes]))
    # Forget gate: input, previous output, and bias.
    fx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
    fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
    fb = tf.Variable(tf.zeros([1, num_nodes]))
    # Memory cell: input, state and bias.
    cx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
    cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
    cb = tf.Variable(tf.zeros([1, num_nodes]))
    # Output gate: input, previous output, and bias.
    ox = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
    om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
    ob = tf.Variable(tf.zeros([1, num_nodes]))
    # Variables saving state across unrollings.
    saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
    saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
    # Classifier weights and biases.
    w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))
    b = tf.Variable(tf.zeros([vocabulary_size]))

    # Definition of the cell computation.
    def lstm_cell(i, o, state):
        """Create an LSTM cell. See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf
        Note that in this formulation, we omit the various connections between the
        previous state and the gates."""
        input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
        forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
        update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb
        state = forget_gate * state + input_gate * tf.tanh(update)
        output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)
        return output_gate * tf.tanh(state), state
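    # For reference, the cell above computes the standard (no-peephole) LSTM update,
    # where i is the current input, o the previous output and state the previous cell:
    #   input_gate  = sigmoid(i @ ix + o @ im + ib)
    #   forget_gate = sigmoid(i @ fx + o @ fm + fb)
    #   state       = forget_gate * state + input_gate * tanh(i @ cx + o @ cm + cb)
    #   output_gate = sigmoid(i @ ox + o @ om + ob)
    #   output      = output_gate * tanh(state)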
    # Input data.
    train_data = list()
    for _ in range(num_unrollings + 1):
        train_data.append(
            tf.placeholder(tf.float32, shape=[batch_size, vocabulary_size]))
    train_inputs = train_data[:num_unrollings]
    train_labels = train_data[1:]  # labels are inputs shifted by one time step.

    # Unrolled LSTM loop.
    outputs = list()
    output = saved_output
    state = saved_state
    for i in train_inputs:
        output, state = lstm_cell(i, output, state)
        outputs.append(output)

    # State saving across unrollings.
    with tf.control_dependencies([saved_output.assign(output),
                                  saved_state.assign(state)]):
        # Classifier.
        logits = tf.nn.xw_plus_b(tf.concat(0, outputs), w, b)
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                logits, tf.concat(0, train_labels)))

    # Optimizer.
    global_step = tf.Variable(0)
    learning_rate = tf.train.exponential_decay(
        10.0, global_step, 5000, 0.1, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    gradients, v = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
    optimizer = optimizer.apply_gradients(
        zip(gradients, v), global_step=global_step)

    # Predictions.
    train_prediction = tf.nn.softmax(logits)

    # Sampling and validation eval: batch 1, no unrolling.
    sample_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size])
    saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))
    saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))
    reset_sample_state = tf.group(
        saved_sample_output.assign(tf.zeros([1, num_nodes])),
        saved_sample_state.assign(tf.zeros([1, num_nodes])))
    sample_output, sample_state = lstm_cell(
        sample_input, saved_sample_output, saved_sample_state)
    with tf.control_dependencies([saved_sample_output.assign(sample_output),
                                  saved_sample_state.assign(sample_state)]):
        sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))
num_steps = 7001
summary_frequency = 100

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    print('Initialized')
    mean_loss = 0
    for step in range(num_steps):
        batches = train_batches.next()
        feed_dict = dict()
        for i in range(num_unrollings + 1):
            feed_dict[train_data[i]] = batches[i]
        _, l, predictions, lr = session.run(
            [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)
        mean_loss += l
        if step % summary_frequency == 0:
            if step > 0:
                mean_loss = mean_loss / summary_frequency
            # The mean loss is an estimate of the loss over the last few batches.
            print('Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))
            mean_loss = 0
            labels = np.concatenate(list(batches)[1:])
            print('Minibatch perplexity: %.2f' % float(
                np.exp(logprob(predictions, labels))))
            if step % (summary_frequency * 10) == 0:
                # Generate some samples.
                print('=' * 80)
                for _ in range(5):
                    feed = sample(random_distribution())
                    sentence = characters(feed)[0]
                    reset_sample_state.run()
                    for _ in range(79):
                        prediction = sample_prediction.eval({sample_input: feed})
                        feed = sample(prediction)
                        sentence += characters(feed)[0]
                    print(sentence)
                print('=' * 80)
            # Measure validation set perplexity.
            reset_sample_state.run()
            valid_logprob = 0
            for _ in range(valid_size):
                b = valid_batches.next()
                predictions = sample_prediction.eval({sample_input: b[0]})
                valid_logprob = valid_logprob + logprob(predictions, b[1])
            print('Validation set perplexity: %.2f' % float(np.exp(
                valid_logprob / valid_size)))
# Problem 1
# ---------
# You might have noticed that the definition of the LSTM cell involves 4 matrix
# multiplications with the input, and 4 matrix multiplications with the output.
# Simplify the expression by using a single matrix multiply for each, and
# variables that are 4 times larger.
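# A minimal NumPy sketch of the trick used below (illustrative only, not part of
# the TensorFlow graph): stacking the four gate weight matrices column-wise lets
# one matrix multiply produce all four gate pre-activations, which are then split.
_x = np.random.randn(2, 5)                           # dummy input [batch, in_dim]
_Wi, _Wf, _Wc, _Wo = (np.random.randn(5, 3) for _ in range(4))
_W = np.concatenate([_Wi, _Wf, _Wc, _Wo], axis=1)    # [in_dim, 4 * num_nodes]
_pre_i, _pre_f, _pre_c, _pre_o = np.split(_x.dot(_W), 4, axis=1)
assert np.allclose(_pre_i, _x.dot(_Wi)) and np.allclose(_pre_o, _x.dot(_Wo))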
graph = tf.Graph()
with graph.as_default():
    # Parameters:
    # Input gate: input, previous output, and bias.
    ix = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
    im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
    ib = tf.Variable(tf.zeros([1, num_nodes]))
    # Forget gate: input, previous output, and bias.
    fx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
    fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
    fb = tf.Variable(tf.zeros([1, num_nodes]))
    # Memory cell: input, state and bias.
    cx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
    cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
    cb = tf.Variable(tf.zeros([1, num_nodes]))
    # Output gate: input, previous output, and bias.
    ox = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
    om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
    ob = tf.Variable(tf.zeros([1, num_nodes]))
    # Concatenate the per-gate parameters into single, 4x larger variables.
    sx = tf.concat(1, [ix, fx, cx, ox])
    sm = tf.concat(1, [im, fm, cm, om])
    sb = tf.concat(1, [ib, fb, cb, ob])
    # Variables saving state across unrollings.
    saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
    saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
    # Classifier weights and biases.
    w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))
    b = tf.Variable(tf.zeros([vocabulary_size]))

    # Definition of the cell computation.
    def lstm_cell(i, o, state):
        """Create an LSTM cell. See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf
        Note that in this formulation, we omit the various connections between the
        previous state and the gates."""
        y = tf.matmul(i, sx) + tf.matmul(o, sm) + sb
        y_input, y_forget, update, y_output = tf.split(1, 4, y)
        input_gate = tf.sigmoid(y_input)
        forget_gate = tf.sigmoid(y_forget)
        output_gate = tf.sigmoid(y_output)
        state = forget_gate * state + input_gate * tf.tanh(update)
        return output_gate * tf.tanh(state), state
    # Input data.
    train_data = list()
    for _ in range(num_unrollings + 1):
        train_data.append(
            tf.placeholder(tf.float32, shape=[batch_size, vocabulary_size]))
    train_inputs = train_data[:num_unrollings]
    train_labels = train_data[1:]  # labels are inputs shifted by one time step.

    # Unrolled LSTM loop.
    outputs = list()
    output = saved_output
    state = saved_state
    for i in train_inputs:
        output, state = lstm_cell(i, output, state)
        outputs.append(output)

    # State saving across unrollings.
    with tf.control_dependencies([saved_output.assign(output),
                                  saved_state.assign(state)]):
        # Classifier.
        logits = tf.nn.xw_plus_b(tf.concat(0, outputs), w, b)
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                logits, tf.concat(0, train_labels)))

    # Optimizer.
    global_step = tf.Variable(0)
    learning_rate = tf.train.exponential_decay(
        10.0, global_step, 5000, 0.1, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    gradients, v = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
    optimizer = optimizer.apply_gradients(
        zip(gradients, v), global_step=global_step)

    # Predictions.
    train_prediction = tf.nn.softmax(logits)

    # Sampling and validation eval: batch 1, no unrolling.
    sample_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size])
    saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))
    saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))
    reset_sample_state = tf.group(
        saved_sample_output.assign(tf.zeros([1, num_nodes])),
        saved_sample_state.assign(tf.zeros([1, num_nodes])))
    sample_output, sample_state = lstm_cell(
        sample_input, saved_sample_output, saved_sample_state)
    with tf.control_dependencies([saved_sample_output.assign(sample_output),
                                  saved_sample_state.assign(sample_state)]):
        sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))
num_steps = 7001
summary_frequency = 100

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    print('Initialized')
    mean_loss = 0
    for step in range(num_steps):
        batches = train_batches.next()
        feed_dict = dict()
        for i in range(num_unrollings + 1):
            feed_dict[train_data[i]] = batches[i]
        _, l, predictions, lr = session.run(
            [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)
        mean_loss += l
        if step % summary_frequency == 0:
            if step > 0:
                mean_loss = mean_loss / summary_frequency
            # The mean loss is an estimate of the loss over the last few batches.
            print('Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))
            mean_loss = 0
            labels = np.concatenate(list(batches)[1:])
            print('Minibatch perplexity: %.2f' % float(
                np.exp(logprob(predictions, labels))))
            if step % (summary_frequency * 10) == 0:
                # Generate some samples.
                print('=' * 80)
                for _ in range(5):
                    feed = sample(random_distribution())
                    sentence = characters(feed)[0]
                    reset_sample_state.run()
                    for _ in range(79):
                        prediction = sample_prediction.eval({sample_input: feed})
                        feed = sample(prediction)
                        sentence += characters(feed)[0]
                    print(sentence)
                print('=' * 80)
            # Measure validation set perplexity.
            reset_sample_state.run()
            valid_logprob = 0
            for _ in range(valid_size):
                b = valid_batches.next()
                predictions = sample_prediction.eval({sample_input: b[0]})
                valid_logprob = valid_logprob + logprob(predictions, b[1])
            print('Validation set perplexity: %.2f' % float(np.exp(
                valid_logprob / valid_size)))
# Problem 2
# ---------
# We want to train an LSTM over bigrams, that is, pairs of consecutive characters
# like 'ab' instead of single characters like 'a'. Since the number of possible
# bigrams is large, feeding them directly to the LSTM using 1-hot encodings will
# lead to a very sparse representation that is very wasteful computationally.
#
# a- Introduce an embedding lookup on the inputs, and feed the embeddings to the
#    LSTM cell instead of the inputs themselves.
#
# b- Write a bigram-based LSTM, modeled on the character LSTM above.
#
# c- Introduce Dropout. For best practices on how to use Dropout in LSTMs, refer
#    to this [article](http://arxiv.org/abs/1409.2329).
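# For example, with vocabulary_size = 27 a bigram id is simply
# char2id(c1) * vocabulary_size + char2id(c2), so 'ab' maps to 1 * 27 + 2 = 29,
# and bi2str(29) (defined below) maps it back to 'ab'.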
bigram_vocabulary_size = vocabulary_size * vocabulary_size


class BigramBatchGenerator(object):
    def __init__(self, text, batch_size, num_unrollings):
        self._text = text
        self._text_size_in_chars = len(text)
        self._text_size = self._text_size_in_chars // 2
        self._batch_size = batch_size
        self._num_unrollings = num_unrollings
        segment = self._text_size // batch_size
        self._cursor = [offset * segment for offset in range(batch_size)]
        self._last_batch = self._next_batch()

    def _next_batch(self):
        batch = np.zeros(shape=self._batch_size, dtype=np.int)
        for b in range(self._batch_size):
            char_idx = self._cursor[b] * 2
            ch1 = char2id(self._text[char_idx])
            if self._text_size_in_chars - 1 == char_idx:
                ch2 = 0  # no trailing character left; fall back to the space id
            else:
                ch2 = char2id(self._text[char_idx + 1])
            batch[b] = ch1 * vocabulary_size + ch2
            self._cursor[b] = (self._cursor[b] + 1) % self._text_size
        return batch

    def next(self):
        batches = [self._last_batch]
        for step in range(self._num_unrollings):
            batches.append(self._next_batch())
        self._last_batch = batches[-1]
        return batches


def bi2str(encoding):
    return id2char(encoding // vocabulary_size) + id2char(encoding % vocabulary_size)


def bigrams(encodings):
    return [bi2str(e) for e in encodings]


def bibatches2string(batches):
    s = [''] * batches[0].shape[0]
    for b in batches:
        s = [''.join(x) for x in zip(s, bigrams(b))]
    return s


bi_onehot = np.zeros((bigram_vocabulary_size, bigram_vocabulary_size))
np.fill_diagonal(bi_onehot, 1)


def bi_one_hot(encodings):
    return [bi_onehot[e] for e in encodings]


train_batches = BigramBatchGenerator(train_text, 8, 8)
valid_batches = BigramBatchGenerator(valid_text, 1, 1)

print(bibatches2string(train_batches.next()))
print(bibatches2string(train_batches.next()))
print(bibatches2string(valid_batches.next()))
print(bibatches2string(valid_batches.next()))
def logprob(predictions, labels):
    """Log-probability of the true labels in a predicted batch."""
    predictions[predictions < 1e-10] = 1e-10
    return np.sum(np.multiply(labels, -np.log(predictions))) / labels.shape[0]


def sample_distribution(distribution):
    """Sample one element from a distribution assumed to be an array of normalized
    probabilities."""
    r = random.uniform(0, 1)
    s = 0
    for i in range(len(distribution)):
        s += distribution[i]
        if s >= r:
            return i
    return len(distribution) - 1


def sample(prediction, size=vocabulary_size):
    """Turn a (column) prediction into 1-hot encoded samples."""
    p = np.zeros(shape=[1, size], dtype=np.float)
    p[0, sample_distribution(prediction[0])] = 1.0
    return p


def one_hot_voc(prediction, size=vocabulary_size):
    """Turn an id (or a length-1 array of ids) into a 1-hot encoded row."""
    p = np.zeros(shape=[1, size], dtype=np.float)
    p[0, prediction[0]] = 1.0
    return p


def random_distribution(size=vocabulary_size):
    """Generate a random column of probabilities."""
    b = np.random.uniform(0.0, 1.0, size=[1, size])
    return b / np.sum(b, 1)[:, None]
# Hyperparameters for the bigram LSTM. The embedding width is not shown in this
# excerpt, so 128 below is an assumed value; batch_size, num_unrollings and
# num_nodes are reused from above.
embedding_size = 128

graph = tf.Graph()
with graph.as_default():
    # Parameters:
    # input weights of all four gates, concatenated.
    x = tf.Variable(tf.truncated_normal([embedding_size, num_nodes * 4], -0.1, 0.1), name='x')
    # memory (recurrent) weights of all four gates, concatenated.
    m = tf.Variable(tf.truncated_normal([num_nodes, num_nodes * 4], -0.1, 0.1), name='m')
    # biases of all four gates, concatenated.
    biases = tf.Variable(tf.zeros([1, num_nodes * 4]))
    # Variables saving state across unrollings.
    saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
    saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
    # Classifier weights and biases.
    w = tf.Variable(tf.truncated_normal([num_nodes, bigram_vocabulary_size], -0.1, 0.1))
    b = tf.Variable(tf.zeros([bigram_vocabulary_size]))
    # Embeddings for all possible bigrams.
    embeddings = tf.Variable(tf.random_uniform([bigram_vocabulary_size, embedding_size], -1.0, 1.0))
    # One-hot encoding used to build the labels.
    np_one_hot = np.zeros((bigram_vocabulary_size, bigram_vocabulary_size))
    np.fill_diagonal(np_one_hot, 1)
    bigram_one_hot = tf.constant(np.reshape(np_one_hot, -1), dtype=tf.float32,
                                 shape=[bigram_vocabulary_size, bigram_vocabulary_size])
    keep_prob = tf.placeholder(tf.float32)
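    # Note: keep_prob is fed as 0.6 only while training; it is set to 1.0 when
    # sampling and when measuring validation perplexity, so dropout is disabled
    # at evaluation time.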
    # Definition of the cell computation.
    def lstm_cell(i, o, state):
        # dropout on the (embedded) input, one of the non-recurrent connections
        i = tf.nn.dropout(i, keep_prob)
        mult = tf.matmul(i, x) + tf.matmul(o, m) + biases
        input_gate = tf.sigmoid(mult[:, :num_nodes])
        forget_gate = tf.sigmoid(mult[:, num_nodes:num_nodes * 2])
        update = mult[:, num_nodes * 2:num_nodes * 3]
        state = forget_gate * state + input_gate * tf.tanh(update)
        output_gate = tf.sigmoid(mult[:, num_nodes * 3:])
        output = tf.nn.dropout(output_gate * tf.tanh(state), keep_prob)
        return output, state
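    # Following the dropout recommendations cited in the problem statement
    # (Zaremba et al., http://arxiv.org/abs/1409.2329), dropout is applied only to
    # the non-recurrent connections: the embedded input and the cell output, never
    # to the recurrent state carried from step to step.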
    # Input data: [num_unrollings + 1, batch_size] bigram ids (the one-hot encoding
    # is gone; we feed just bigram ids).
    tf_train_data = tf.placeholder(tf.int32, shape=[num_unrollings + 1, batch_size])
    train_data = list()
    for i in tf.split(0, num_unrollings + 1, tf_train_data):
        train_data.append(tf.squeeze(i))
    train_inputs = train_data[:num_unrollings]
    train_labels = list()
    for l in train_data[1:]:  # labels are inputs shifted by one time step.
        train_labels.append(tf.gather(bigram_one_hot, l))

    # Unrolled LSTM loop.
    outputs = list()
    output = saved_output
    state = saved_state
    # python loop used: tensorflow does not support sequential operations yet
    for i in train_inputs:  # having a loop simulates having time
        # embed input bigrams -> [batch_size, embedding_size]
        output, state = lstm_cell(tf.nn.embedding_lookup(embeddings, i), output, state)
        outputs.append(output)

    # State saving across unrollings; control_dependencies makes sure that output
    # and state are computed before the loss.
    with tf.control_dependencies([saved_output.assign(output), saved_state.assign(state)]):
        logits = tf.nn.xw_plus_b(tf.concat(0, outputs), w, b)
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits, tf.concat(0, train_labels)))

    # Optimizer.
    global_step = tf.Variable(0)
    learning_rate = tf.train.exponential_decay(10.0, global_step, 500, 0.9, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    gradients, v = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
    optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)

    # here we predict the embedding
    # train_prediction = tf.argmax(tf.nn.softmax(logits), 1, name='train_prediction')
    train_prediction = tf.nn.softmax(logits)

    # Sampling and validation eval: batch 1, no unrolling.
    sample_input = tf.placeholder(tf.int32, shape=[1])
    saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))
    saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))
    reset_sample_state = tf.group(saved_sample_output.assign(tf.zeros([1, num_nodes])),
                                  saved_sample_state.assign(tf.zeros([1, num_nodes])))
    embed_sample_input = tf.nn.embedding_lookup(embeddings, sample_input)
    sample_output, sample_state = lstm_cell(embed_sample_input, saved_sample_output, saved_sample_state)

    with tf.control_dependencies([saved_sample_output.assign(sample_output),
                                  saved_sample_state.assign(sample_state)]):
        sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))
summary_frequency = 100
# Initialize the bigram batch generators inside the session below.

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    print('Initialized')
    train_batches = BigramBatchGenerator(train_text, batch_size, num_unrollings)
    valid_batches = BigramBatchGenerator(valid_text, 1, 1)
    mean_loss = 0
    for step in range(num_steps):
        batches = train_batches.next()
        _, l, lr, predictions = session.run([optimizer, loss, learning_rate, train_prediction],
                                            feed_dict={tf_train_data: batches, keep_prob: 0.6})
        mean_loss += l
        if step % summary_frequency == 0:
            if step > 0:
                mean_loss = mean_loss / summary_frequency
            # The mean loss is an estimate of the loss over the last few batches.
            print('Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))
            mean_loss = 0
            labels = list(batches)[1:]
            labels = np.concatenate([bi_one_hot(l) for l in labels])
            print('Minibatch perplexity: %.2f' % float(np.exp(logprob(predictions, labels))))
            if step % (summary_frequency * 10) == 0:
                # Generate some samples.
                print('=' * 80)
                for _ in range(5):
                    feed = np.argmax(sample(random_distribution(bigram_vocabulary_size), bigram_vocabulary_size))
                    sentence = bi2str(feed)
                    reset_sample_state.run()
                    for _ in range(79):
                        prediction = sample_prediction.eval({sample_input: [feed], keep_prob: 1.0})
                        feed = np.argmax(sample(prediction, bigram_vocabulary_size))
                        sentence += bi2str(feed)
                    print(sentence)
                print('=' * 80)
            # Measure validation set perplexity.
            reset_sample_state.run()
            valid_logprob = 0
            for _ in range(valid_size):
                b = valid_batches.next()
                predictions = sample_prediction.eval({sample_input: b[0], keep_prob: 1.0})
                valid_logprob = valid_logprob + logprob(predictions, one_hot_voc(b[1], bigram_vocabulary_size))
            print('Validation set perplexity: %.2f' % float(np.exp(valid_logprob / valid_size)))
# Problem 3
# ---------
# Write a sequence-to-sequence LSTM which mirrors all the words in a sentence.
# For example, if your input is:
#
#     the quick brown fox
#
# the model should attempt to output:
#
#     eht kciuq nworb xof
#
# Refer to the lecture on how to put together a sequence-to-sequence model, as well
# as [this article](http://arxiv.org/abs/1409.3215) for best practices.
# A small standalone sketch of the target transformation follows.
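# A minimal sketch of the target transformation (illustrative; `mirror_words` is a
# hypothetical helper; the model below uses rev_id, which additionally converts
# the characters to vocabulary ids):
def mirror_words(sentence):
    # reverse each whitespace-separated word, keeping the word order
    return ' '.join(word[::-1] for word in sentence.split(' '))


print(mirror_words('the quick brown fox'))  # -> 'eht kciuq nworb xof'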
from tensorflow.models.rnn.translate import seq2seq_model


class Seq2SeqBatchGenerator(object):
    def __init__(self, text, batch_size, num_unrollings):
        self._text = text
        self._text_size = len(text)
        self._batch_size = batch_size
        self._num_unrollings = num_unrollings
        segment = self._text_size // num_unrollings
        self._cursor = [offset * segment for offset in range(batch_size)]
        self._last_batch = self._next_batch(0)

    def _next_batch(self, step):
        """Generate a single batch (a string of num_unrollings characters) from the
        current cursor position in the data."""
        batch = ''
        for b in range(self._num_unrollings):
            self._cursor[step] %= self._text_size
            batch += self._text[self._cursor[step]]
            self._cursor[step] += 1
        return batch

    def next(self):
        """Generate the next array of batches from the data. The array consists of
        the last batch of the previous array, followed by batch_size new ones."""
        batches = [self._last_batch]
        for step in range(self._batch_size):
            batches.append(self._next_batch(step))
        self._last_batch = batches[-1]
        return batches


def characters(probabilities):
    """Turn a 1-hot encoding or a probability distribution over the possible
    characters back into its (most likely) character representation."""
    return [id2char(c) for c in np.argmax(probabilities, 1)]


def ids(probabilities):
    """Turn a 1-hot encoding or a probability distribution over the possible
    characters back into its (most likely) id representation."""
    return [str(c) for c in np.argmax(probabilities, 1)]


def batches2id(batches):
    """Convert a sequence of batches back into their (most likely) id-string
    representation."""
    s = [''] * batches[0].shape[0]
    for b in batches:
        s = [''.join(x) for x in zip(s, ids(b))]
    return s


train_batches = Seq2SeqBatchGenerator(train_text, batch_size, num_unrollings)
valid_batches = Seq2SeqBatchGenerator(valid_text, 1, num_unrollings)
def rev_id(forward):
    """Reverse every word of `forward` and map the result to vocabulary ids."""
    temp = forward.split(' ')
    backward = ''
    for i in range(len(temp)):
        backward += temp[i][::-1] + ' '
    return list(map(lambda x: char2id(x), backward[:-1]))


batches = train_batches.next()
batch_encs = list(map(lambda x: list(map(lambda y: char2id(y), list(x))), batches))
batch_decs = list(map(lambda x: rev_id(x), batches))
print('x=', ''.join([id2char(x) for x in batch_encs[0]]))
print('y=', ''.join([id2char(x) for x in batch_decs[0]]))
def create_model(forward_only):
    # The bucket sizes, hidden size, number of layers, learning rate and use_lstm
    # flag are not shown in this excerpt; the values below are assumed placeholders.
    model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_size,
                                       target_vocab_size=vocabulary_size,
                                       buckets=[(20, 20)],
                                       size=256,
                                       num_layers=3,
                                       max_gradient_norm=5.0,
                                       batch_size=batch_size,
                                       learning_rate=1.0,
                                       learning_rate_decay_factor=0.9,
                                       use_lstm=True,
                                       forward_only=forward_only)
    return model
with tf.Session() as sess:
    model = create_model(False)
    sess.run(tf.initialize_all_variables())

    # This is the training loop.
    step_time, loss = 0.0, 0.0
    previous_losses = []
    # Checkpoint / evaluation frequencies (values assumed; not shown in this excerpt).
    step_ckpt = 100
    valid_ckpt = 500

    for step in range(1, num_steps):
        model.batch_size = batch_size
        batches = train_batches.next()
        train_sets = []
        batch_encs = list(map(lambda x: list(map(lambda y: char2id(y), list(x))), batches))
        batch_decs = list(map(lambda x: rev_id(x), batches))
        for i in range(len(batch_encs)):
            train_sets.append((batch_encs[i], batch_decs[i]))

        # Get a batch and make a step.
        encoder_inputs, decoder_inputs, target_weights = model.get_batch([train_sets], 0)
        _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, 0, False)
        loss += step_loss / step_ckpt

        # Once in a while, we save checkpoint, print statistics, and run evals.
        if step % step_ckpt == 0:
            # Print statistics for the previous epoch.
            perplexity = math.exp(loss) if loss < 300 else float('inf')
            print("global step %d learning rate %.4f perplexity "
                  "%.2f" % (model.global_step.eval(), model.learning_rate.eval(), perplexity))
            # Decrease learning rate if no improvement was seen over last 3 times.
            if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                sess.run(model.learning_rate_decay_op)
            previous_losses.append(loss)
            loss = 0.0

        if step % valid_ckpt == 0:
            # Decode a fixed test sentence with the current model.
            model.batch_size = 1
            test_sets = []
            batches = ['the quick brown fox']
            batch_encs = list(map(lambda x: list(map(lambda y: char2id(y), list(x))), batches))
            # batch_decs = map(lambda x: rev_id(x), batches)
            test_sets.append((batch_encs[0], []))
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch([test_sets], 0)
            # Get output logits for the sentence.
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, 0, True)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            print('>>>>>>>>> ', batches[0], ' -> ', ''.join(map(lambda x: id2char(x), outputs)))

            # Measure validation set perplexity.
            v_loss = 0.0
            for _ in range(valid_size):
                valid_sets = []
                v_batches = valid_batches.next()
                v_batch_encs = list(map(lambda x: list(map(lambda y: char2id(y), list(x))), v_batches))
                v_batch_decs = list(map(lambda x: rev_id(x), v_batches))
                for i in range(len(v_batch_encs)):
                    valid_sets.append((v_batch_encs[i], v_batch_decs[i]))
                encoder_inputs, decoder_inputs, target_weights = model.get_batch([valid_sets], 0)
                _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, 0, True)
                v_loss += eval_loss / valid_size

            eval_ppx = math.exp(v_loss) if v_loss < 300 else float('inf')
            print(" valid eval: perplexity %.2f" % eval_ppx)

    # reuse variable -> subdivide into two boxes
    model.batch_size = 1  # We decode one sentence at a time.
    test_sets = []
    batches = ['the quick brown fox']
    batch_encs = list(map(lambda x: list(map(lambda y: char2id(y), list(x))), batches))
    # batch_decs = map(lambda x: rev_id(x), batches)
    test_sets.append((batch_encs[0], []))
    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = model.get_batch([test_sets], 0)
    # Get output logits for the sentence.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, 0, True)
    # This is a greedy decoder - outputs are just argmaxes of output_logits.
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
    print('## : ', outputs)
    # If there is an EOS symbol in outputs, cut them at that point.
    if char2id('!') in outputs:
        outputs = outputs[:outputs.index(char2id('!'))]
    print(batches[0], ' -> ', ''.join(map(lambda x: id2char(x), outputs)))